from dis import Instruction

import torch
from openai.datalib.common import INSTRUCTIONS


def cross_entropy_loss(logits, targets, reduction='mean'):
    """Compute multi-class cross-entropy by indexing, without one-hot encoding.

    Log-probabilities are obtained with ``log_softmax`` rather than
    ``log(softmax(...))`` so that a target probability which underflows to
    zero yields a large finite loss instead of ``inf``. ``ignore_index`` is
    not supported.

    Args:
        logits: (N, C) torch.Tensor of unnormalized scores.
        targets: (N,) torch.LongTensor of class labels in [0, C-1].
        reduction: 'mean' | 'sum' | 'none'.

    Returns:
        A scalar tensor for 'mean' / 'sum', or an (N,) tensor for 'none'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    # Fail fast on a bad argument before doing any tensor work.
    if reduction not in ('none', 'sum', 'mean'):
        raise ValueError("reduction 必须是 'none' | 'sum' | 'mean'")

    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)  # (N, C)

    # Select each row's log-probability at its target class: loss = -log(p_y).
    rows = torch.arange(logits.size(0), device=logits.device)
    loss = -log_probs[rows, targets]  # (N,)

    if reduction == 'none':
        return loss
    if reduction == 'sum':
        return loss.sum()
    return loss.mean()


# Example usage of cross_entropy_loss (kept commented out):
# logits: (batch=2, num_classes=3)
# logits = torch.tensor([[2.0, 1.0, 0.1],
#                        [0.5, 2.5, 0.3]], requires_grad=True)
# targets = torch.tensor([0, 1])   # ground-truth classes
#
# loss = cross_entropy_loss(logits, targets, reduction='mean')
# print("loss:", loss.item())
#
# # backward pass
# loss.backward()
# print("logits.grad:", logits.grad)



def ce_one_hot(logits, targets, reduction='mean'):
    """Compute multi-class cross-entropy via an explicit one-hot mask.

    Each sample's loss is ``-sum_c one_hot[c] * log_softmax(logits)[c]``.
    ``log_softmax`` is used instead of ``log(softmax(...))`` because a
    probability that underflows to zero would otherwise give ``log(0) = -inf``
    and the masked product ``0 * -inf`` would turn the loss into NaN.

    Args:
        logits: (N, C) torch.Tensor of unnormalized scores.
        targets: (N,) torch.LongTensor of class labels in [0, C-1].
        reduction: 'mean' | 'sum' | 'none'.

    Returns:
        A scalar tensor for 'mean' / 'sum', or an (N,) tensor for 'none'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    # Reject unknown reductions explicitly instead of silently returning None.
    if reduction not in ('none', 'sum', 'mean'):
        raise ValueError("reduction 必须是 'none' | 'sum' | 'mean'")

    num_classes = logits.size(-1)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)  # (N, C)
    one_hot = torch.nn.functional.one_hot(targets, num_classes=num_classes)

    # The mask zeroes every non-target entry; summing leaves log(p_y) per row.
    loss = -torch.sum(one_hot * log_probs, dim=-1)  # (N,)

    if reduction == 'none':
        return loss
    if reduction == 'sum':
        return loss.sum()
    return loss.mean()


# Demo run: a batch of two samples over three classes.
logits = torch.tensor(
    [
        [2.0, 1.0, 0.1],
        [0.5, 2.5, 0.3],
    ],
    requires_grad=True,
)
# Ground-truth class index for each sample.
targets = torch.tensor([0, 1])

# Mean cross-entropy over the batch, printed as a Python float.
loss = ce_one_hot(logits, targets, reduction='mean')
print("loss:", loss.item())






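# ---------------------------------------------------------------------------
# Pseudocode notes (not executable Python) -- kept as comments so that the
# module can be parsed.
# ---------------------------------------------------------------------------
# get GUI ScreenSpot
# get user INSTRUCTIONS
# history = []
# inputs = processor(
#     ScreenSpot,
#     INSTRUCTIONS,
# )
# history.append(inputs)
#
# steps = model.analysis(INSTRUCTIONS)
#
# for step in steps:
#     prediction = model.generate(
#         step, ScreenSpot
#     )
#     # If execution does not succeed, repeat the action.
#     #
#     # A, B, C, D
#     preModelJudge() --> rank prediction
#     new ScreenSpot = Action(prediction)
#     Judge(new ScreenSpot, step)  -->
#     history.append(prediction)
#     ScreenSpot = new ScreenSpot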




